{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Unsupervised learning\n",
    "\n",
    "When we do error-modulated learning with the `nengo.PES` rule,\n",
    "we have a pretty clear idea of what we want to happen.\n",
    "Unsupervised learning rules are less clear-cut.\n",
    "Years of neuroscientific experiments have yielded\n",
    "learning rules explaining how synaptic strengths\n",
    "change given certain stimulation protocols.\n",
    "But what do these learning rules actually do\n",
    "to the information transmitted across an\n",
    "ensemble-to-ensemble connection?\n",
    "\n",
    "We can investigate this in Nengo.\n",
    "Nengo currently implements two such rules,\n",
    "`nengo.BCM` and `nengo.Oja`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import nengo"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(nengo.BCM.__doc__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(nengo.Oja.__doc__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a simple communication channel\n",
    "\n",
    "The only difference between this network and most\n",
    "models you've seen so far is that we set\n",
    "`weights=True` on the connection's solver,\n",
    "so that it generates a full connection weight matrix\n",
    "instead of decoders.\n",
    "BCM and Oja act on individual synaptic weights,\n",
    "so they need that full matrix to learn on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = nengo.Network()\n",
    "with model:\n",
    "    # Sine wave input with an angular frequency of 4 rad/s\n",
    "    sin = nengo.Node(lambda t: np.sin(t * 4))\n",
    "\n",
    "    pre = nengo.Ensemble(100, dimensions=1)\n",
    "    post = nengo.Ensemble(100, dimensions=1)\n",
    "\n",
    "    nengo.Connection(sin, pre)\n",
    "    # weights=True solves for a full pre-to-post weight matrix\n",
    "    conn = nengo.Connection(\n",
    "        pre, post, solver=nengo.solvers.LstsqL2(weights=True))\n",
    "\n",
    "    # Probe the decoded value of each ensemble\n",
    "    pre_p = nengo.Probe(pre, synapse=0.01)\n",
    "    post_p = nengo.Probe(post, synapse=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify that the network behaves as a communication channel\n",
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(2.0)\n",
    "\n",
    "plt.figure()\n",
    "plt.plot(sim.trange(), sim.data[pre_p], label=\"Pre\")\n",
    "plt.plot(sim.trange(), sim.data[post_p], label=\"Post\")\n",
    "plt.ylabel(\"Decoded value\")\n",
    "plt.legend(loc=\"best\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What does BCM do?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach BCM to the existing connection and record its weights over time\n",
    "conn.learning_rule_type = nengo.BCM(learning_rate=5e-10)\n",
    "with model:\n",
    "    weights_p = nengo.Probe(conn, \"weights\", synapse=0.01, sample_every=0.01)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(20.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 8))\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.plot(sim.trange(), sim.data[pre_p], label=\"Pre\")\n",
    "plt.plot(sim.trange(), sim.data[post_p], label=\"Post\")\n",
    "plt.ylabel(\"Decoded value\")\n",
    "plt.ylim(-1.6, 1.6)\n",
    "plt.legend(loc=\"lower left\")\n",
    "plt.subplot(2, 1, 2)\n",
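    "# Find the weight row (postsynaptic neuron) with max variance over time\n",
    "neuron = np.argmax(np.mean(np.var(sim.data[weights_p], axis=0), axis=1))\n",
    "# Weights are stored as (time, post, pre), so index the row on axis 1\n",
    "plt.plot(sim.trange(sample_every=0.01), sim.data[weights_p][:, neuron, :])\n",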
    "plt.ylabel(\"Connection weight\");"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The BCM rule appears to drive the ensemble\n",
    "to be either on or off.\n",
    "This seems consistent with the idea that it potentiates\n",
    "active synapses and depresses inactive ones.\n",
    "\n",
    "We can also show that BCM sparsifies the weights.\n",
    "The measure below is the Gini index,\n",
    "for reasons explained [in this paper](https://arxiv.org/pdf/0811.4706.pdf)."
   ]
  },
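  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the potentiate-or-depress intuition concrete,\n",
    "here is a sketch of the BCM update as we understand\n",
    "Nengo's reference implementation;\n",
    "treat the `nengo.BCM` docstring printed earlier\n",
    "as the authoritative description.\n",
    "The symbols are our own notation, not taken from the docstring:\n",
    "$a_j$ is the filtered activity of presynaptic neuron $j$,\n",
    "$a_i$ is the filtered activity of postsynaptic neuron $i$,\n",
    "$\\theta_i$ is a slower running average of $a_i$,\n",
    "and $\\kappa$ is the learning rate.\n",
    "\n",
    "$$\\Delta \\omega_{ij} = \\kappa \\, a_i \\, (a_i - \\theta_i) \\, a_j$$\n",
    "\n",
    "Under this form, a weight grows when the postsynaptic neuron\n",
    "fires above its own recent average ($a_i > \\theta_i$)\n",
    "and shrinks when it fires below that average,\n",
    "while a silent presynaptic or postsynaptic neuron\n",
    "($a_j = 0$ or $a_i = 0$) leaves the weight unchanged."
   ]
  },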
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
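    "def sparsity_measure(vector):  # Gini index\n",
    "    # 0 when all magnitudes are equal; approaches 1 (exactly 1 - 1/n)\n",
    "    # when a single entry holds all of the magnitude.\n",
    "    # Flatten first so a weight matrix is treated as one long vector.\n",
    "    v = np.sort(np.abs(np.ravel(vector)))\n",
    "    n = v.shape[0]\n",
    "    k = np.arange(n) + 1\n",
    "    l1norm = np.sum(v)\n",
    "    summation = np.sum((v / l1norm) * ((n - k + 0.5) / n))\n",
    "    return 1 - 2 * summation\n",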
    "\n",
    "\n",
    "print(\n",
    "    \"Starting sparsity: {0}\".format(sparsity_measure(sim.data[weights_p][0])))\n",
    "print(\"Ending sparsity: {0}\".format(sparsity_measure(sim.data[weights_p][-1])))"
   ]
  },
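  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the `sparsity_measure` function above follows,\n",
    "as far as we can tell, the Gini index defined in the paper linked earlier.\n",
    "The notation here is ours:\n",
    "$c_{(1)} \\le c_{(2)} \\le \\dots \\le c_{(N)}$ are the absolute values\n",
    "of the $N$ entries sorted in ascending order,\n",
    "and $\\|c\\|_1$ is their sum.\n",
    "\n",
    "$$S(c) = 1 - 2 \\sum_{k=1}^{N} \\frac{c_{(k)}}{\\|c\\|_1} \\left( \\frac{N - k + \\frac{1}{2}}{N} \\right)$$\n",
    "\n",
    "Two quick checks help build intuition.\n",
    "If all entries have the same magnitude,\n",
    "each ratio $c_{(k)} / \\|c\\|_1$ equals $1/N$,\n",
    "the sum evaluates to $\\frac{1}{2}$, and $S = 0$.\n",
    "If only one entry is nonzero, only the $k = N$ term survives,\n",
    "the sum evaluates to $\\frac{1}{2N}$, and $S = 1 - \\frac{1}{N}$,\n",
    "which approaches 1 as $N$ grows."
   ]
  },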
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What does Oja do?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace BCM with Oja on the same connection; the weights probe is reused\n",
    "conn.learning_rule_type = nengo.Oja(learning_rate=6e-8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(20.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 8))\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.plot(sim.trange(), sim.data[pre_p], label=\"Pre\")\n",
    "plt.plot(sim.trange(), sim.data[post_p], label=\"Post\")\n",
    "plt.ylabel(\"Decoded value\")\n",
    "plt.ylim(-1.6, 1.6)\n",
    "plt.legend(loc=\"lower left\")\n",
    "plt.subplot(2, 1, 2)\n",
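    "# Find the weight row (postsynaptic neuron) with max variance over time\n",
    "neuron = np.argmax(np.mean(np.var(sim.data[weights_p], axis=0), axis=1))\n",
    "# Weights are stored as (time, post, pre), so index the row on axis 1\n",
    "plt.plot(sim.trange(sample_every=0.01), sim.data[weights_p][:, neuron, :])\n",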
    "plt.ylabel(\"Connection weight\")\n",
    "print(\n",
    "    \"Starting sparsity: {0}\".format(sparsity_measure(sim.data[weights_p][0])))\n",
    "print(\"Ending sparsity: {0}\".format(sparsity_measure(sim.data[weights_p][-1])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The Oja rule seems to attenuate the signal across the connection."
   ]
  },
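  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to see where the attenuation might come from\n",
    "is to write out the Oja update as we understand\n",
    "Nengo's reference implementation;\n",
    "treat the `nengo.Oja` docstring printed earlier\n",
    "as the authoritative description.\n",
    "The symbols are our own notation:\n",
    "$a_j$ and $a_i$ are the filtered presynaptic and postsynaptic activities,\n",
    "$\\kappa$ is the learning rate,\n",
    "and $\\beta$ scales the forgetting term\n",
    "(the `beta` parameter of `nengo.Oja`).\n",
    "\n",
    "$$\\Delta \\omega_{ij} = \\kappa \\left( a_i a_j - \\beta \\, a_i^2 \\, \\omega_{ij} \\right)$$\n",
    "\n",
    "The first term is plain Hebbian coactivity.\n",
    "The second term pulls each weight back toward zero\n",
    "in proportion to the weight itself and to the squared\n",
    "postsynaptic activity, so large weights on active neurons\n",
    "decay fastest.\n",
    "That decay is a plausible, though unverified, explanation\n",
    "for the smaller decoded signal seen above."
   ]
  },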
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## What do BCM and Oja together do?\n",
    "\n",
    "We can apply both learning rules to the same connection\n",
    "by passing a list to `learning_rule_type`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conn.learning_rule_type = [\n",
    "    nengo.BCM(learning_rate=5e-10), nengo.Oja(learning_rate=2e-9)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with nengo.Simulator(model) as sim:\n",
    "    sim.run(20.0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(12, 8))\n",
    "plt.subplot(2, 1, 1)\n",
    "plt.plot(sim.trange(), sim.data[pre_p], label=\"Pre\")\n",
    "plt.plot(sim.trange(), sim.data[post_p], label=\"Post\")\n",
    "plt.ylabel(\"Decoded value\")\n",
    "plt.ylim(-1.6, 1.6)\n",
    "plt.legend(loc=\"lower left\")\n",
    "plt.subplot(2, 1, 2)\n",
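    "# Find the weight row (postsynaptic neuron) with max variance over time\n",
    "neuron = np.argmax(np.mean(np.var(sim.data[weights_p], axis=0), axis=1))\n",
    "# Weights are stored as (time, post, pre), so index the row on axis 1\n",
    "plt.plot(sim.trange(sample_every=0.01), sim.data[weights_p][:, neuron, :])\n",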
    "plt.ylabel(\"Connection weight\")\n",
    "print(\n",
    "    \"Starting sparsity: {0}\".format(sparsity_measure(sim.data[weights_p][0])))\n",
    "print(\"Ending sparsity: {0}\".format(sparsity_measure(sim.data[weights_p][-1])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The combination of the two appears to accentuate\n",
    "the on-off nature of the BCM rule.\n",
    "Because the Oja rule enforces a soft bound\n",
    "on the magnitude of each connection weight, the combination\n",
    "of both rules may be more stable than BCM alone.\n",
    "\n",
    "That's mostly conjecture, however;\n",
    "a thorough investigation of the BCM and Oja rules\n",
    "and how they interact has not yet been done."
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
